{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "csv.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DweYe9FcbMK_"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "AVV2e0XKbJeX",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sUtoed20cRJJ"
      },
      "source": [
        "# 用 tf.data 加载 CSV 数据"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1ap_W4aQcgNT"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/load_data/csv\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />在 Tensorflow.org 上查看</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/zh-cn/tutorials/load_data/csv.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />在 Google Colab 运行</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/zh-cn/tutorials/load_data/csv.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />在 Github 上查看源代码</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/site/zh-cn/tutorials/load_data/csv.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />下载此 notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z4x9SYyev8QZ",
        "colab_type": "text"
      },
      "source": [
        "Note: 我们的 TensorFlow 社区翻译了这些文档。因为社区翻译是尽力而为， 所以无法保证它们是最准确的，并且反映了最新的\n",
        "[官方英文文档](https://www.tensorflow.org/?hl=en)。如果您有改进此翻译的建议， 请提交 pull request 到\n",
        "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub 仓库。要志愿地撰写或者审核译文，请加入\n",
        "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "C-3Xbt0FfGfs"
      },
      "source": [
        "这篇教程通过一个示例展示了怎样将 CSV 格式的数据加载进 `tf.data.Dataset`。\n",
        "\n",
        "这篇教程使用的是泰坦尼克号乘客的数据。模型会根据乘客的年龄、性别、票务舱和是否独自旅行等特征来预测乘客生还的可能性。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fgZ9gjmPfSnK"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "I4dwMQVQMQWD",
        "colab": {}
      },
      "source": [
        "try:\n",
        "  # Colab only\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "    pass\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "baYFZMW_bJHh",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "import functools\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Ncf5t6tgL5ZI",
        "colab": {}
      },
      "source": [
        "TRAIN_DATA_URL = \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\"\n",
        "TEST_DATA_URL = \"https://storage.googleapis.com/tf-datasets/titanic/eval.csv\"\n",
        "\n",
        "train_file_path = tf.keras.utils.get_file(\"train.csv\", TRAIN_DATA_URL)\n",
        "test_file_path = tf.keras.utils.get_file(\"eval.csv\", TEST_DATA_URL)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "4ONE94qulk6S",
        "colab": {}
      },
      "source": [
        "# 让 numpy 数据更易读。\n",
        "np.set_printoptions(precision=3, suppress=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Wuqj601Qw0Ml"
      },
      "source": [
        "## 加载数据\n",
        "\n",
        "开始的时候，我们通过打印 CSV 文件的前几行来了解文件的格式。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "54Dv7mCrf9Yw",
        "colab": {}
      },
      "source": [
        "!head {train_file_path}"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YOYKQKmMj3D6"
      },
      "source": [
        "正如你看到的那样，CSV 文件的每列都会有一个列名。dataset 的构造函数会自动识别这些列名。如果你使用的文件的第一行不包含列名，那么需要将列名通过字符串列表传给 `make_csv_dataset` 函数的 `column_names` 参数。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZS-bt1LvWn2x"
      },
      "source": [
        " \n",
        "\n",
        "\n",
        "\n",
        "```python\n",
        "\n",
        "CSV_COLUMNS = ['survived', 'sex', 'age', 'n_siblings_spouses', 'parch', 'fare', 'class', 'deck', 'embark_town', 'alone']\n",
        "\n",
        "dataset = tf.data.experimental.make_csv_dataset(\n",
        "     ...,\n",
        "     column_names=CSV_COLUMNS,\n",
        "     ...)\n",
        "  \n",
        "```\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gZfhoX7bR9u4"
      },
      "source": [
        "这个示例使用了所有的列。如果你需要忽略数据集中的某些列，创建一个包含你需要使用的列的列表，然后传给构造器的（可选）参数 `select_columns`。\n",
        "\n",
        "```python\n",
        "\n",
        "dataset = tf.data.experimental.make_csv_dataset(\n",
        "  ...,\n",
        "  select_columns = columns_to_use, \n",
        "  ...)\n",
        "\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "67mfwr4v-mN_"
      },
      "source": [
        "对于包含模型需要预测的值的列是你需要显式指定的。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "iXROZm5f3V4E",
        "colab": {}
      },
      "source": [
        "LABEL_COLUMN = 'survived'\n",
        "LABELS = [0, 1]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "t4N-plO4tDXd"
      },
      "source": [
        "现在从文件中读取 CSV 数据并且创建 dataset。\n",
        "\n",
        "(完整的文档，参考 `tf.data.experimental.make_csv_dataset`)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Co7UJ7gpNADC",
        "colab": {}
      },
      "source": [
        "def get_dataset(file_path):\n",
        "  dataset = tf.data.experimental.make_csv_dataset(\n",
        "      file_path,\n",
        "      batch_size=12, # 为了示例更容易展示，手动设置较小的值\n",
        "      label_name=LABEL_COLUMN,\n",
        "      na_value=\"?\",\n",
        "      num_epochs=1,\n",
        "      ignore_errors=True)\n",
        "  return dataset\n",
        "\n",
        "raw_train_data = get_dataset(train_file_path)\n",
        "raw_test_data = get_dataset(test_file_path)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vHUQFKoQI6G7"
      },
      "source": [
        "dataset 中的每个条目都是一个批次，用一个元组（*多个样本*，*多个标签*）表示。样本中的数据组织形式是以列为主的张量（而不是以行为主的张量），每条数据中包含的元素个数就是批次大小（这个示例中是 12）。\n",
        "\n",
        "阅读下面的示例有助于你的理解。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "qWtFYtwXIeuj",
        "colab": {}
      },
      "source": [
        "examples, labels = next(iter(raw_train_data)) # 第一个批次\n",
        "print(\"EXAMPLES: \\n\", examples, \"\\n\")\n",
        "print(\"LABELS: \\n\", labels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9cryz31lxs3e"
      },
      "source": [
        "## 数据预处理"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tSyrkSQwYHKi"
      },
      "source": [
        "### 分类数据\n",
        "\n",
        "CSV 数据中的有些列是分类的列。也就是说，这些列只能在有限的集合中取值。\n",
        "\n",
        "使用 `tf.feature_column` API 创建一个 `tf.feature_column.indicator_column` 集合，每个 `tf.feature_column.indicator_column` 对应一个分类的列。\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "mWDniduKMw-C",
        "colab": {}
      },
      "source": [
        "CATEGORIES = {\n",
        "    'sex': ['male', 'female'],\n",
        "    'class' : ['First', 'Second', 'Third'],\n",
        "    'deck' : ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],\n",
        "    'embark_town' : ['Cherbourg', 'Southhampton', 'Queenstown'],\n",
        "    'alone' : ['y', 'n']\n",
        "}\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kkxLdrsLwHPT",
        "colab": {}
      },
      "source": [
        "categorical_columns = []\n",
        "for feature, vocab in CATEGORIES.items():\n",
        "  cat_col = tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "        key=feature, vocabulary_list=vocab)\n",
        "  categorical_columns.append(tf.feature_column.indicator_column(cat_col))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "H18CxpHY_Nma",
        "colab": {}
      },
      "source": [
        "# 你刚才创建的内容\n",
        "categorical_columns"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "R7-1QG99_1sN"
      },
      "source": [
        "这将是后续构建模型时处理输入数据的一部分。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9AsbaFmCeJtF"
      },
      "source": [
        "### 连续数据"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "o2maE8d2ijsq"
      },
      "source": [
        "连续数据需要标准化。\n",
        "\n",
        "写一个函数标准化这些值，然后将这些值改造成 2 维的张量。\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "REKqO_xHPNx0",
        "colab": {}
      },
      "source": [
        "def process_continuous_data(mean, data):\n",
        "  # 标准化数据\n",
        "  data = tf.cast(data, tf.float32) * 1/(2*mean)\n",
        "  return tf.reshape(data, [-1, 1])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VPsoMUgRCpUM"
      },
      "source": [
        "现在创建一个数值列的集合。`tf.feature_columns.numeric_column` API 会使用 `normalizer_fn` 参数。在传参的时候使用  [`functools.partial`](https://docs.python.org/3/library/functools.html#functools.partial)，`functools.partial` 由使用每个列的均值进行标准化的函数构成。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WKT1ASWpwH46",
        "colab": {}
      },
      "source": [
        "MEANS = {\n",
        "    'age' : 29.631308,\n",
        "    'n_siblings_spouses' : 0.545455,\n",
        "    'parch' : 0.379585,\n",
        "    'fare' : 34.385399\n",
        "}\n",
        "\n",
        "numerical_columns = []\n",
        "\n",
        "for feature in MEANS.keys():\n",
        "  num_col = tf.feature_column.numeric_column(feature, normalizer_fn=functools.partial(process_continuous_data, MEANS[feature]))\n",
        "  numerical_columns.append(num_col)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Bw0I35xRS57V",
        "colab": {}
      },
      "source": [
        "# 你刚才创建的内容。\n",
        "numerical_columns"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "M37oD2VcCO4R"
      },
      "source": [
        "这里使用标准化的方法需要提前知道每列的均值。如果需要计算连续的数据流的标准化的值可以使用 [TensorFlow Transform](https://www.tensorflow.org/tfx/transform/get_started)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kPWkC4_1l3IG"
      },
      "source": [
        "### 创建预处理层"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "R3QAjo1qD4p9"
      },
      "source": [
        "将这两个特征列的集合相加，并且传给 `tf.keras.layers.DenseFeatures` 从而创建一个进行预处理的输入层。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3-OYK7GnaH0r",
        "colab": {}
      },
      "source": [
        "preprocessing_layer = tf.keras.layers.DenseFeatures(categorical_columns+numerical_columns)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DlF_omQqtnOP"
      },
      "source": [
        "## 构建模型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lQoFh16LxtT_"
      },
      "source": [
        "从 `preprocessing_layer` 开始构建 `tf.keras.Sequential`。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3mSGsHTFPvFo",
        "colab": {}
      },
      "source": [
        "model = tf.keras.Sequential([\n",
        "  preprocessing_layer,\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  tf.keras.layers.Dense(1, activation='sigmoid'),\n",
        "])\n",
        "\n",
        "model.compile(\n",
        "    loss='binary_crossentropy',\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hPdtI2ie0lEZ"
      },
      "source": [
        "## 训练、评估和预测"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8gvw1RE9zXkD"
      },
      "source": [
        "现在可以实例化和训练模型。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "sW-4XlLeEQ2B",
        "colab": {}
      },
      "source": [
        "train_data = raw_train_data.shuffle(500)\n",
        "test_data = raw_test_data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Q_nm28IzNDTO",
        "colab": {}
      },
      "source": [
        "model.fit(train_data, epochs=20)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QyDMgBurzqQo"
      },
      "source": [
        "当模型训练完成的时候，你可以在测试集 `test_data` 上检查准确性。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "eB3R3ViVONOp",
        "colab": {}
      },
      "source": [
        "test_loss, test_accuracy = model.evaluate(test_data)\n",
        "\n",
        "print('\\n\\nTest Loss {}, Test Accuracy {}'.format(test_loss, test_accuracy))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sTrn_pD90gdJ"
      },
      "source": [
        "使用 `tf.keras.Model.predict` 推断一个批次或多个批次的标签。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Qwcx74F3ojqe",
        "colab": {}
      },
      "source": [
        "predictions = model.predict(test_data)\n",
        "\n",
        "# 显示部分结果\n",
        "for prediction, survived in zip(predictions[:10], list(test_data)[0][1][:10]):\n",
        "  print(\"Predicted survival: {:.2%}\".format(prediction[0]),\n",
        "        \" | Actual outcome: \",\n",
        "        (\"SURVIVED\" if bool(survived) else \"DIED\"))\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
